import os
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from itertools import chain
from datasets import load_dataset
from transformers import AutoTokenizer
import torch


def _to_tensor_dataset(split):
    """Convert a grouped dataset split into a ``TensorDataset``.

    The returned dataset yields ``(input_ids, attention_mask, labels)``
    tuples of ``torch.long`` tensors, in that order.

    Args:
        split: a dataset split (anything indexable by column name) whose
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` columns
            are lists of equal-length integer lists (as produced by the
            block-size grouping in ``wikitext103``).
    """
    return TensorDataset(
        *(
            torch.tensor(split[column], dtype=torch.long)
            for column in ("input_ids", "attention_mask", "labels")
        )
    )


def wikitext103(config):
    """Build train/eval ``DataLoader``s for WikiText-103 (raw, v1).

    The dataset is loaded, tokenized with ``config.tokenizer.name`` and then
    regrouped into contiguous chunks of ``block_size`` tokens, where
    ``block_size = min(config.block_size, config.tokenizer.model_max_length)``.
    Trailing tokens that do not fill a whole chunk (per ``map`` batch) are
    dropped. Each loader batch is an ``(input_ids, attention_mask, labels)``
    tuple of ``torch.long`` tensors; ``labels`` is a copy of ``input_ids``.

    When ``RANK`` is set (distributed run), rank 0 performs the loading and
    preprocessing first while the other ranks wait at a barrier; they then
    repeat the same calls, which can reuse what rank 0 cached under
    ``config.train.path`` unless ``config.overwrite_cache`` is set.

    Args:
        config: data config; reads ``train.path`` (cache dir),
            ``tokenizer.name``, ``tokenizer.model_max_length``,
            ``block_size``, ``num_workers``, ``overwrite_cache``,
            ``train.train_batch`` and ``test.test_batch``.

    Returns:
        ``(train_loader, eval_loader)``. The train loader samples randomly
        (``RandomSampler``, or ``DistributedSampler`` when ``RANK`` is set);
        the eval loader walks the validation split in order.
    """
    # Global rank from the launcher; -1 means a non-distributed run.
    gpu_id = int(os.getenv("RANK", -1))
    if gpu_id not in [-1, 0]:
        # Non-zero ranks wait until rank 0 has downloaded, tokenized and
        # grouped the data; rank 0 releases them at the matching barrier.
        dist.barrier()

    raw_datasets = load_dataset(
        "wikitext", "wikitext-103-raw-v1", cache_dir=config.train.path
    )
    # ``use_fast`` is the keyword AutoTokenizer recognises; the former
    # ``use_fast_tokenizer`` was not, and was merely passed through to the
    # tokenizer's init kwargs.
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer.name, cache_dir=config.train.path, use_fast=True
    )

    def tokenize_function(examples):
        # Batched: ``examples["text"]`` is a list of raw strings.
        return tokenizer(examples["text"])

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=config.num_workers,
        remove_columns="text",
        load_from_cache_file=not config.overwrite_cache,
        desc="Running tokenizer on dataset",
    )
    # NOTE(review): the limit comes from the config, not from
    # ``tokenizer.model_max_length`` -- confirm that is intended.
    block_size = min(config.block_size, config.tokenizer.model_max_length)

    def group_texts(examples):
        """Concatenate a batch of tokenized texts and cut it into
        ``block_size`` chunks, dropping the remainder (so a batch shorter
        than ``block_size`` yields no rows)."""
        concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Round down to a multiple of block_size; padding could be used
        # instead if the model supported it.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        # Causal-LM targets are the inputs themselves (shifting is left to
        # the model / loss).
        result["labels"] = result["input_ids"].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=config.num_workers,
        load_from_cache_file=not config.overwrite_cache,
        desc=f"Grouping texts in chunks of {block_size}",
    )

    if gpu_id == 0:
        # Rank 0 is done preprocessing: release the waiting ranks.
        dist.barrier()

    train_dataset = _to_tensor_dataset(lm_datasets["train"])
    eval_dataset = _to_tensor_dataset(lm_datasets["validation"])

    # NOTE(review): DistributedSampler shuffles with ``seed`` on every rank,
    # and the seed must be identical across ranks for the shards to be
    # disjoint. ``initial_seed()`` only guarantees that if every rank called
    # ``torch.manual_seed`` with the same value -- confirm in the caller.
    # The caller should also call ``sampler.set_epoch(epoch)`` per epoch.
    train_sampler = (
        RandomSampler(train_dataset)
        if gpu_id == -1
        else DistributedSampler(train_dataset, seed=torch.random.initial_seed())
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=config.num_workers,
        batch_size=config.train.train_batch,
    )

    eval_loader = DataLoader(
        eval_dataset,
        sampler=SequentialSampler(eval_dataset),
        num_workers=config.num_workers,
        batch_size=config.test.test_batch,
    )

    return train_loader, eval_loader
